{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a7d19182",
   "metadata": {},
   "source": [
    "# Mixed Dirichlet\n",
    "\n",
    "\n",
    "A mixed Dirichlet random variable $Y$ takes on values in the probability simplex $\\Delta_{K-1}$. An assignment $Y=y$ has probability density\n",
    "\n",
    "\\begin{align}\n",
    "P_{Y}(y|\\alpha, w) &= \\sum_{f} \\mathrm{Gibbs}(f|w) \\times \\mathrm{Dirichlet}(y|\\alpha \\odot f)\n",
    "\\end{align}\n",
    "\n",
    "where $w \\in \\mathbb R^K$, $\\alpha \\in \\mathbb R^K_{>0}$, and $f$ ranges over the non-empty faces of the simplex. By $\\alpha \\odot f$ we mean the sub-vector of $\\alpha$ whose coordinates are associated with the vertices in $f$.\n",
    "\n",
    "The distribution over proper faces has probability mass function:\n",
    "\\begin{align}\n",
    "\\mathrm{Gibbs}(f|w) = \\frac{\\exp(w^\\top \\phi(f))}{\\sum_{f'} \\exp(w^\\top \\phi(f'))}\n",
    "\\end{align}\n",
    "where $\\phi(f) \\in \\{-1, 1\\}^K$ is such that $\\phi_k(f) = 1$ if the vertex $\\mathbf e_k$ is in the face, and $-1$ otherwise.\n"
   ]
  },
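  {
   "cell_type": "markdown",
   "id": "e1a4b7c2",
   "metadata": {},
   "source": [
    "To make the definition concrete, here is a small case worked out by hand from the two formulas above (it is not output of the `mixed` package; the symbols $Z(w)$ and $\\delta_{\\mathbf e_k}$ are introduced only for this example). For $K=2$ there are three non-empty faces, $\\{\\mathbf e_1\\}$, $\\{\\mathbf e_2\\}$ and $\\{\\mathbf e_1, \\mathbf e_2\\}$, with $\\phi(f)$ equal to $(1,-1)$, $(-1,1)$ and $(1,1)$, respectively. The normaliser of the Gibbs distribution is\n",
    "\n",
    "\\begin{align}\n",
    "Z(w) &= e^{w_1 - w_2} + e^{w_2 - w_1} + e^{w_1 + w_2}\n",
    "\\end{align}\n",
    "\n",
    "so for $w = \\mathbf 0$ each face has probability $1/3$. A Dirichlet over a single vertex is a point mass at that vertex, so, writing $\\delta_{\\mathbf e_k}$ for that point mass,\n",
    "\n",
    "\\begin{align}\n",
    "P_{Y}(y|\\alpha, \\mathbf 0) &= \\tfrac{1}{3}\\,\\delta_{\\mathbf e_1}(y) + \\tfrac{1}{3}\\,\\delta_{\\mathbf e_2}(y) + \\tfrac{1}{3}\\,\\mathrm{Dirichlet}(y|\\alpha)\n",
    "\\end{align}\n",
    "\n",
    "In words: a sample lands exactly on a vertex with probability $2/3$, and strictly inside the segment between the two vertices otherwise."
   ]
  },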
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ec2ce34",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.distributions as td\n",
    "from mixed import MixedDirichlet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1bb53fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "MixedDirichlet(concentration=torch.ones(3), scores=torch.zeros(3)).sample([10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa7a510a",
   "metadata": {},
   "outputs": [],
   "source": [
    "MixedDirichlet(concentration=torch.ones(3)/100, scores=torch.zeros(3)).sample([10])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "76fd4565",
   "metadata": {},
   "source": [
    "## Understanding the parts\n",
    "\n",
    "An efficient GPU-friendly implementation of Mixed Dirichlet distributions builds on two auxiliary distributions: a GPU-friendly discrete exponential family over proper faces, and a GPU-friendly Dirichlet distribution (for which we can batch distributions of varying dimensionality, from $1$ to $K$)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "506b1a74",
   "metadata": {},
   "outputs": [],
   "source": [
    "from bitvector import NonEmptyBitVector"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "49be6f52",
   "metadata": {},
   "source": [
    "We can use bit-vectors to encode each of the faces of the simplex. A proper face has $1$ to $K$ vertices, so we use a $K$-dimensional bit-vector $f$. If $f_k = 1$, the vertex $\\mathbf e_k$ is in the face.\n",
    "\n",
    "A distribution over the proper faces can be obtained by scoring each of the vertices independently, i.e.,\n",
    "\n",
    "$P_F(f) \\propto \\exp\\left(\\sum_{k=1}^K (-1)^{1-f_k}w_k\\right)$\n",
    "\n",
    "The normaliser of this expression sums over the set of proper faces only, so it excludes the all-zeros bit-vector $\\mathbf 0$.\n",
    "\n",
    "The class `bitvector.NonEmptyBitVector` implements a GPU-friendly version of the necessary procedure, which is based on a directed acyclic graph (DAG) of size $\\mathcal O(K)$."
   ]
  },
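  {
   "cell_type": "markdown",
   "id": "f2b5c8d3",
   "metadata": {},
   "source": [
    "For small $K$ the distribution over faces can be checked by brute force. The sketch below is an illustration written for that purpose, not the DAG-based procedure of `bitvector.NonEmptyBitVector`: the helper name `face_pmf_bruteforce` is hypothetical, it uses only the Python standard library, and it enumerates all $2^K - 1$ non-zero bit-vectors, so its cost is exponential in $K$.\n",
    "\n",
    "```python\n",
    "import itertools\n",
    "import math\n",
    "\n",
    "\n",
    "def face_pmf_bruteforce(w):\n",
    "    # Hypothetical helper: returns {face: probability} for every non-zero\n",
    "    # bit-vector of length len(w). Cost is O(2^K); for small checks only.\n",
    "    K = len(w)\n",
    "    # Every bit-vector except the all-zeros one, which is not a face.\n",
    "    faces = [f for f in itertools.product([0, 1], repeat=K) if any(f)]\n",
    "    # Unnormalised log-score: +w_k if vertex k is in the face, -w_k otherwise.\n",
    "    scores = [sum(wk if fk else -wk for wk, fk in zip(w, f)) for f in faces]\n",
    "    # Subtract the maximum before exponentiating, for numerical stability.\n",
    "    m = max(scores)\n",
    "    weights = [math.exp(s - m) for s in scores]\n",
    "    Z = sum(weights)\n",
    "    return {f: wt / Z for f, wt in zip(faces, weights)}\n",
    "\n",
    "\n",
    "pmf = face_pmf_bruteforce([0.0, 0.0, 0.0])\n",
    "print(len(pmf))        # number of non-empty faces for K=3\n",
    "print(pmf[(1, 1, 1)])  # probability of the full simplex under zero scores\n",
    "```\n",
    "\n",
    "With zero scores all faces are equally likely, so for $K=3$ each of the $7$ faces gets probability $1/7$."
   ]
  },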
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e39946b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "NonEmptyBitVector(torch.zeros(3)).enumerate_support()"
   ]
  },
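  {
   "cell_type": "markdown",
   "id": "46376cb1",
   "metadata": {},
   "source": [
    "Here is a batch of two such distributions, each with a different parameter vector. Averaging the samples estimates, for each distribution, the marginal probability that each vertex is in the face:"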
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b445e0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "NonEmptyBitVector(torch.stack([torch.zeros(3), torch.ones(3)], 0)).sample([1000]).mean(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e8562c35",
   "metadata": {},
   "source": [
    "If $A_k \\sim \\mathrm{Gamma}(\\alpha_k, 1)$ independently for $k = 1, \\ldots, K$, and $T = \\sum_{k=1}^K A_k$, then\n",
    "\\begin{align}\n",
    "    \\left(\\frac{A_1}{T}, \\ldots, \\frac{A_K}{T} \\right)^\\top & \\sim \\mathrm{Dirichlet}(\\alpha)\n",
    "\\end{align}\n",
    "\n",
    "\n",
    "We can use this fact to implement a Dirichlet distribution parametrized by a shared vector of $K$ concentration parameters and a *mask* that identifies the face of the simplex on which the Dirichlet is supported.\n",
    "\n",
    "The class `dirichlet.MaskedDirichlet` implements such a GPU-friendly distribution."
   ]
  },
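  {
   "cell_type": "markdown",
   "id": "a3c6d9e4",
   "metadata": {},
   "source": [
    "The Gamma construction above suggests how a mask can select a face: draw a Gamma variate only for the vertices in the face, set the remaining coordinates to zero, and normalise. The sketch below illustrates this with the Python standard library. It is an assumption about the idea, not the code of `dirichlet.MaskedDirichlet`; the helper name `masked_dirichlet_sample` is hypothetical, it draws a single sample rather than a batch, it requires a mask with at least one vertex, and it has no guard against underflow for very small concentrations.\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "\n",
    "def masked_dirichlet_sample(alpha, mask, rng=random):\n",
    "    # Hypothetical helper: one draw from a Dirichlet supported on the face\n",
    "    # selected by mask, using the Gamma normalisation.\n",
    "    # A_k ~ Gamma(alpha_k, 1) for vertices in the face, A_k = 0 otherwise.\n",
    "    a = [rng.gammavariate(ak, 1.0) if mk else 0.0 for ak, mk in zip(alpha, mask)]\n",
    "    total = sum(a)\n",
    "    # Normalising puts the sample on the face; masked-out coordinates stay exactly 0.\n",
    "    return [ak / total for ak in a]\n",
    "\n",
    "\n",
    "rng = random.Random(0)\n",
    "alpha = [1.0, 1.0, 1.0]  # one shared concentration vector\n",
    "y_edge = masked_dirichlet_sample(alpha, [False, True, True], rng)  # face {e_2, e_3}\n",
    "y_full = masked_dirichlet_sample(alpha, [True, True, True], rng)   # entire simplex\n",
    "print(y_edge)\n",
    "print(y_full)\n",
    "```\n",
    "\n",
    "Because every distribution keeps all $K$ coordinates and differs only in its mask, distributions over faces of different dimensionality have samples of the same shape, which is what makes batching them straightforward."
   ]
  },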
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20af5c0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from dirichlet import MaskedDirichlet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21881f69",
   "metadata": {},
   "outputs": [],
   "source": [
    "f = NonEmptyBitVector(torch.zeros(3)).sample([2])\n",
    "f"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5294765a",
   "metadata": {},
   "source": [
    "In the following example, one Dirichlet is defined over the face that contains $\\mathbf e_2$ and $\\mathbf e_3$, while the other is defined over the entire simplex."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5b9d8d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "MaskedDirichlet(\n",
    "    mask=torch.stack([torch.tensor([False, True, True]), torch.tensor([True, True, True])], 0), \n",
    "    concentration=torch.stack([torch.ones(3), torch.ones(3)/100], 0)\n",
    ").sample([2])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "380473c2",
   "metadata": {},
   "source": [
    "# Uniform F and Uniform Y|f"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3494a757",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt"
   ]
  },
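  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e5082ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_marginals(samples, bins=100):\n",
    "    # One normalised histogram per coordinate; the last dimension of `samples` indexes coordinates.\n",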
    "    D = samples.shape[-1]\n",
    "    fig, ax = plt.subplots(D, 1, figsize=(4, 2*D), sharex=True)\n",
    "    for d in range(D):\n",
    "        _ = ax[d].hist(samples[...,d].flatten().numpy(), bins=bins, density=True)\n",
    "    return fig, ax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "501a1b29",
   "metadata": {},
   "outputs": [],
   "source": [
    "p3d = MixedDirichlet(concentration=torch.ones(3), scores=torch.zeros(3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b7273ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "p3d.sample([10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "faef453d",
   "metadata": {},
   "outputs": [],
   "source": [
    "p3d.entropy(), p3d.cross_entropy(p3d), td.kl_divergence(p3d, p3d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb8ca6a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "_ = plot_marginals(p3d.sample([1000]), bins=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94a6d07e",
   "metadata": {},
   "outputs": [],
   "source": [
    "_p = p3d.expand([2, 1])\n",
    "_p.sample().shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26782fe0",
   "metadata": {},
   "outputs": [],
   "source": [
    "_p.entropy(), _p.cross_entropy(_p), td.kl_divergence(_p, _p)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88d857fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "_p.faces.cross_entropy(_p.faces).shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0aa64075",
   "metadata": {},
   "source": [
    "# Max-Ent F and Uniform Y|f"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b5a6de6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from bitvector import MaxEntropyFaces"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82dbb23d",
   "metadata": {},
   "outputs": [],
   "source": [
    "pm3d = MixedDirichlet(concentration=torch.ones(3), pmf_n=MaxEntropyFaces.pmf_n(3, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc89f0d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "_ = plot_marginals(pm3d.sample(torch.Size([1000])), bins=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2def7413",
   "metadata": {},
   "outputs": [],
   "source": [
    "pm3d.entropy(), pm3d.cross_entropy(pm3d), td.kl_divergence(pm3d, pm3d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f2f1d02",
   "metadata": {},
   "outputs": [],
   "source": [
    "_pm = pm3d.expand([2, 1])\n",
    "_pm.sample().shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b54f549",
   "metadata": {},
   "outputs": [],
   "source": [
    "_pm.entropy(), _pm.cross_entropy(_pm), td.kl_divergence(_pm, _pm)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0ab922e7",
   "metadata": {},
   "source": [
    "# VI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c12a454f",
   "metadata": {},
   "outputs": [],
   "source": [
    "p = MixedDirichlet(concentration=torch.ones(5), pmf_n=MaxEntropyFaces.pmf_n(5, 1))\n",
    "q = MixedDirichlet(concentration=torch.ones(5)/10, scores=torch.zeros(5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55df04a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "p.batch_shape, p.event_shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e9caef0",
   "metadata": {},
   "outputs": [],
   "source": [
    "p.sample(torch.Size([10]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "432f9258",
   "metadata": {},
   "outputs": [],
   "source": [
    "f = p.faces.enumerate_support()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3569bf3",
   "metadata": {},
   "outputs": [],
   "source": [
    "p.faces.log_prob(f).exp(), f.sum(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91578ff4",
   "metadata": {},
   "outputs": [],
   "source": [
    "p.cross_entropy(q)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9cf2227c",
   "metadata": {},
   "outputs": [],
   "source": [
    "p.Y(f).cross_entropy(q.Y(f))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26f53f98",
   "metadata": {},
   "outputs": [],
   "source": [
    "p.Y(f).entropy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0a6de00",
   "metadata": {},
   "outputs": [],
   "source": [
    "td.kl_divergence(p, q)"
   ]
  },
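  {
   "cell_type": "markdown",
   "id": "b4d7eaf5",
   "metadata": {},
   "source": [
    "The quantities computed in this section are related by a standard identity, stated here as background and not as a description of how the `mixed` package computes them. Write $P_F$ and $Q_F$ for the face distributions of `p` and `q`, and $P_{Y|f}$ and $Q_{Y|f}$ for the corresponding Dirichlets on the face $f$ (this notation is introduced only for this note). The face is almost surely determined by $y$, since it is the set of non-zero coordinates, so the chain rule for the KL divergence gives\n",
    "\n",
    "\\begin{align}\n",
    "\\mathrm{KL}(P_Y \\,\\|\\, Q_Y) &= \\mathrm{KL}(P_F \\,\\|\\, Q_F) + \\mathbb E_{f \\sim P_F}\\left[ \\mathrm{KL}(P_{Y|f} \\,\\|\\, Q_{Y|f}) \\right]\n",
    "\\end{align}\n",
    "\n",
    "The first term compares the two distributions over faces. The second averages, over the faces drawn from $P_F$, the divergence between the two Dirichlets restricted to that face; for a face with a single vertex both are the same point mass and the term is $0$."
   ]
  },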
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa290656",
   "metadata": {},
   "outputs": [],
   "source": [
    "td.kl_divergence(p.faces, q.faces)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
